S$^{2}$FT: Efficient, Scalable and Generalizable LLM Fine-tuning by Structured Sparsity

  • 2024-12-19 18:47:54
  • Xinyu Yang, Jixuan Leng, Geyang Guo, Jiawei Zhao, Ryumei Nakada, Linjun Zhang, Huaxiu Yao, Beidi Chen

Abstract

Current parameter-efficient fine-tuning (PEFT) methods for LLMs can achieve high quality, efficient training, or scalable serving, but not all three simultaneously. To address this limitation, we investigate sparse fine-tuning and observe a remarkable improvement in generalization ability. Building on this key insight, we propose a family of Structured Sparse Fine-Tuning (S$^{2}$FT) methods for LLMs, which concurrently achieve state-of-the-art fine-tuning performance, training efficiency, and inference scalability. S$^{2}$FT accomplishes this by "selecting sparsely and computing densely". It selects a few heads in the multi-head attention (MHA) module and a few channels in the feed-forward network (FFN) module of each Transformer block. Next, it co-permutes the weight matrices on both sides of the coupled structures in LLMs so that the selected components in each layer form a dense submatrix. Finally, S$^{2}$FT performs in-place gradient updates on all submatrices. Through theoretical analysis and empirical results, we show that our method prevents forgetting while simplifying optimization, delivers SOTA performance on both commonsense and arithmetic reasoning with 4.6% and 1.3% average improvements over LoRA, and surpasses full FT by 11.5% when generalizing to various domains after instruction tuning. Using our partial backpropagation algorithm, S$^{2}$FT saves up to 3$\times$ training memory and improves latency by 1.5-2.7$\times$ compared to full FT, while delivering an average 10% improvement over LoRA on both metrics. We further demonstrate that the weight updates in S$^{2}$FT can be decoupled into adapters, enabling effective fusion, fast switching, and efficient parallelism for serving multiple fine-tuned models.
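The "select sparsely, compute densely" recipe can be illustrated on a single toy FFN block. What follows is a minimal, dependency-free sketch written under our own assumptions, not the authors' implementation: it uses a ReLU FFN with hypothetical sizes, plain SGD on a squared-error loss, and hypothetical helper names (`co_permute`, `partial_backprop_step`, `extract_adapter`, `apply_adapter`). It shows that permuting the rows of the up-projection and the columns of the down-projection with the same order leaves the output unchanged. After this permutation, gradient updates touch only two contiguous blocks, and the resulting change can be stored as a small adapter.

```python
import copy
import random


def matvec(W, x):
    """Multiply a matrix stored as a list of rows by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def ffn(W_up, W_down, x):
    """Toy FFN y = W_down @ relu(W_up @ x); also returns the hidden activations."""
    h_pre = matvec(W_up, x)
    h = [max(0.0, a) for a in h_pre]
    return matvec(W_down, h), h_pre, h


def co_permute(W_up, W_down, selected):
    """Move the selected hidden channels to the front of both matrices.

    Rows of W_up and columns of W_down use the same order, so the FFN
    output is unchanged (ReLU acts on each channel independently).
    """
    order = list(selected) + [c for c in range(len(W_up)) if c not in selected]
    W_up_p = [list(W_up[c]) for c in order]
    W_down_p = [[row[c] for c in order] for row in W_down]
    return W_up_p, W_down_p


def partial_backprop_step(W_up, W_down, x, target, k, lr):
    """In-place SGD step on the dense blocks W_up[:k, :] and W_down[:, :k].

    Loss is 0.5 * ||y - target||^2. Gradients are computed only for the
    first k (selected) hidden channels; every other weight stays frozen.
    """
    y, h_pre, h = ffn(W_up, W_down, x)
    err = [yi - ti for yi, ti in zip(y, target)]
    # Backpropagate into the selected channels only (before any weight changes).
    g_h = [sum(W_down[i][j] * err[i] for i in range(len(err))) for j in range(k)]
    g_pre = [g_h[j] if h_pre[j] > 0 else 0.0 for j in range(k)]
    for i in range(len(W_down)):  # columns 0..k-1 of W_down
        for j in range(k):
            W_down[i][j] -= lr * err[i] * h[j]
    for j in range(k):  # rows 0..k-1 of W_up
        for m in range(len(x)):
            W_up[j][m] -= lr * g_pre[j] * x[m]
    return 0.5 * sum(e * e for e in err)


def extract_adapter(base_up, base_down, tuned_up, tuned_down, k):
    """Keep only the two dense blocks that training changed."""
    d_up = [[t - b for t, b in zip(tuned_up[j], base_up[j])] for j in range(k)]
    d_down = [[tuned_down[i][j] - base_down[i][j] for j in range(k)]
              for i in range(len(base_down))]
    return d_up, d_down


def apply_adapter(base_up, base_down, adapter):
    """Return copies of the base weights with the adapter blocks added."""
    d_up, d_down = adapter
    W_up, W_down = copy.deepcopy(base_up), copy.deepcopy(base_down)
    for j, row in enumerate(d_up):
        for m, d in enumerate(row):
            W_up[j][m] += d
    for i, row in enumerate(d_down):
        for j, d in enumerate(row):
            W_down[i][j] += d
    return W_up, W_down


# Hypothetical toy sizes: 4 inputs, 6 hidden channels, 3 outputs; channels 4 and 1 selected.
random.seed(0)
D_IN, HIDDEN, D_OUT, SELECTED = 4, 6, 3, [4, 1]
W_up = [[random.uniform(-1, 1) for _ in range(D_IN)] for _ in range(HIDDEN)]
W_down = [[random.uniform(-1, 1) for _ in range(HIDDEN)] for _ in range(D_OUT)]
x, target = [0.5, -1.0, 2.0, 0.3], [1.0, 0.0, -1.0]

k = len(SELECTED)
W_up_p, W_down_p = co_permute(W_up, W_down, SELECTED)
base_up, base_down = copy.deepcopy(W_up_p), copy.deepcopy(W_down_p)
losses = [partial_backprop_step(W_up_p, W_down_p, x, target, k, lr=0.01)
          for _ in range(20)]
adapter = extract_adapter(base_up, base_down, W_up_p, W_down_p, k)
```

In this toy setting, rows `k:` of `W_up_p` and columns `k:` of `W_down_p` are never written, so they stay bit-identical to the base weights. Serving another task would then only require swapping the stored `(d_up, d_down)` blocks, which `apply_adapter` restores up to floating-point rounding. For attention, the analogous coupled structure would be the per-head slices of the query/key/value projections and the matching slice of the output projection. That extension is not shown here, and whether S$^{2}$FT's training and serving paths work exactly this way is an assumption of the sketch.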